/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2017-2019. All rights reserved.
 * Description: The description of class OpIRFuncFactory of op_builder
 */

// inc/framework
#include "framework/graph/core/infershape/op_ir_func_factory.h"
#include "framework/graph/core/node/node.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/debug/ge_log.h"

using namespace std;

namespace ge {
// Returns a pointer to the single, process-wide OpIRFuncFactory.
// The factory is a function-local static, so it is constructed the first
// time this accessor is called and the same address is returned afterwards.
OpIRFuncFactory* OpIRFuncFactory::Instance()
{
    static OpIRFuncFactory factory;
    OpIRFuncFactory* const self = &factory;
    return self;
}

// Looks up the infer-shape function registered for the op type of `node`.
//
// The op type is read from the node's NodeSpec role. When no entry exists
// for that type an error is logged and nullptr is returned. An entry that
// exists but holds a null infer function is returned as-is (no log).
OpInferFunc OpIRFuncFactory::GetInferFunc(Node& node)
{
    NodeSpec& nodeSpec = node.ROLE(NodeSpec);
    auto found = funcsMap_.find(nodeSpec.Type());
    if (found == funcsMap_.end()) {
        // Nothing registered for this op type.
        FMK_LOGE("get [op:%s type:%s] infershape func failed.", nodeSpec.Name().c_str(),
            nodeSpec.Type().c_str());
        return nullptr;
    }

    return found->second.inferFunc;
}

// Looks up the verify function registered for the op type of `node`.
//
// The op type is read from the node's NodeSpec role. Returns nullptr when
// the type has no entry; unlike GetInferFunc, a miss is not logged here.
OpVerifyFunc OpIRFuncFactory::GetVerifyFunc(Node& node)
{
    NodeSpec& nodeSpec = node.ROLE(NodeSpec);
    auto found = funcsMap_.find(nodeSpec.Type());
    if (found == funcsMap_.end()) {
        return nullptr;
    }

    return found->second.verifyFunc;
}

void OpIRFuncFactory::RegisterInferFunc(const std::string& type, OpInferFunc fun)
{
    std::map<std::string, OpIRFuncs>::const_iterator iter = funcsMap_.find(type);
    if (iter != funcsMap_.end()) {
        OpIRFuncs opIRFuncs = iter->second;
        if (opIRFuncs.inferFunc != nullptr) {
            FMK_LOGE("op type:%s already exist.", type.c_str());
            return;
        }
        opIRFuncs.inferFunc = fun;
        funcsMap_[type] = opIRFuncs;
        return;
    }
    OpIRFuncs opIRFuncs;
    opIRFuncs.verifyFunc = nullptr;
    opIRFuncs.inferFunc = fun;
    funcsMap_[type] = opIRFuncs;
}

void OpIRFuncFactory::RegisterVerifyFunc(const std::string& type, OpVerifyFunc fun)
{
    std::map<std::string, OpIRFuncs>::const_iterator iter = funcsMap_.find(type);
    if (iter != funcsMap_.end()) {
        OpIRFuncs opIRFuncs = iter->second;
        if (opIRFuncs.verifyFunc != nullptr) {
            FMK_LOGE("op type:%s already exist.", type.c_str());
            return;
        }
        opIRFuncs.verifyFunc = fun;
        funcsMap_[type] = opIRFuncs;
        return;
    }
    OpIRFuncs opIRFuncs;
    opIRFuncs.verifyFunc = fun;
    opIRFuncs.inferFunc = nullptr;
    funcsMap_[type] = opIRFuncs;
}

// Removes the entry for op type `type`, dropping both its infer and verify
// functions. Does nothing when the type is not registered.
void OpIRFuncFactory::UnregisterOpIRFunc(const std::string& type)
{
    auto pos = funcsMap_.find(type);
    if (pos == funcsMap_.end()) {
        return;
    }

    // Reset the stored functions before the node itself is removed.
    pos->second = {};
    funcsMap_.erase(pos);
}

// Registers `inferFunc` and `verifyFunc` for op type `type` with the global
// OpIRFuncFactory and remembers the type so the destructor can unregister it.
//
// NOTE(review): the register calls report duplicates only through the log,
// so this object cannot tell whether its registration was accepted. Its
// destructor unregisters `type` unconditionally, which would also remove an
// entry owned by an earlier registrar of the same type - confirm that
// duplicate registrars never occur in practice.
OpInferAndVerifyRegisterar::OpInferAndVerifyRegisterar(
    const std::string& type, OpInferFunc inferFunc, OpVerifyFunc verifyFunc)
{
    OpIRFuncFactory* factory = OpIRFuncFactory::Instance();
    factory->RegisterInferFunc(type, inferFunc);
    factory->RegisterVerifyFunc(type, verifyFunc);
    type_ = type;
}

// Removes the factory entry for the op type recorded at construction time.
OpInferAndVerifyRegisterar::~OpInferAndVerifyRegisterar()
{
    OpIRFuncFactory* factory = OpIRFuncFactory::Instance();
    factory->UnregisterOpIRFunc(type_);
}
} // namespace ge
